import json
import os
import torch

# Filename prefixes for saved state dicts. These are not complete paths:
# save_log / save_log_plus append the epoch number and '.pt' to them.
# They are relative, so they resolve against the current working directory.
MODEL_PATH = 'model/saved_model/model'
MODEL_PATH_PLUS = 'model/saved_model/model_plus'


def save_log(data):
    """Save the model's state dict on epochs where ``epoch % 10 == 1``.

    For integer epochs that means 1, 11, 21, ...; on any other epoch the
    function returns without touching the filesystem.  The checkpoint is
    written to ``MODEL_PATH + str(epoch) + '.pt'``.

    Args:
        data: Mapping with the keys ``'model'`` (an object exposing
            ``state_dict()``) and ``'epoch'`` (the current epoch number).

    Raises:
        KeyError: If ``'model'`` or ``'epoch'`` is missing from ``data``.
    """
    model = data['model']
    epoch = data['epoch']
    if epoch % 10 != 1:
        return

    checkpoint_path = f'{MODEL_PATH}{epoch!s}.pt'
    # torch.save does not create missing parent directories, so make sure
    # the target directory exists before the first checkpoint is written.
    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)


def save_log_plus(data):
    """Save the model's state dict under the "plus" prefix.

    A checkpoint is written only when ``data['epoch'] % 10 == 1`` (for
    integer epochs: 1, 11, 21, ...); otherwise the function returns without
    touching the filesystem.  The checkpoint is written to
    ``MODEL_PATH_PLUS + str(epoch) + '.pt'``.

    Args:
        data: Mapping with the keys ``'model'`` (an object exposing
            ``state_dict()``) and ``'epoch'`` (the current epoch number).

    Raises:
        KeyError: If ``'model'`` or ``'epoch'`` is missing from ``data``.
    """
    model = data['model']
    epoch = data['epoch']
    if epoch % 10 != 1:
        return

    checkpoint_path = f'{MODEL_PATH_PLUS}{epoch!s}.pt'
    # torch.save does not create missing parent directories, so make sure
    # the target directory exists before the first checkpoint is written.
    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)
